#!/usr/bin/env python  
# -*- coding: utf-8 -*-
""" 
@author:hello_life 
@license: Apache Licence 
@file: main.py 
@time: 2022/04/23
@software: PyCharm 
description: Entry script that builds the IMDB data loaders and trains the ALBERT model.
"""
import torch
from torch.utils.data import DataLoader

from model.albert_model import Albert_model
from parameters.albert_config import Config
from utils.train_utils import train_loop,evaluate_loop
from utils.imdb_data import IMDBDataset,collate_fn

if __name__ == '__main__':
    # Load the run configuration (data paths, batch size, device).
    config = Config()

    # Build one dataset per split from the paths given in the config.
    train_dataset = IMDBDataset(config.train_path)
    test_dataset = IMDBDataset(config.test_path)
    val_dataset = IMDBDataset(config.val_path)

    # All loaders share the batch size and collate function; only the
    # training loader is shuffled so evaluation order stays deterministic.
    loader_kwargs = dict(batch_size=config.batch_size, collate_fn=collate_fn)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    # Fix: this loader previously wrapped train_dataset, so the validation
    # split was never used and evaluation would have run on training data.
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Instantiate the model and move it to the configured device.
    model = Albert_model(config).to(config.device)

    # Train.
    # NOTE(review): the test loader is passed as train_loop's second loader;
    # if that argument is used for model selection during training, consider
    # passing val_dataloader instead -- confirm against utils.train_utils.
    train_loop(model, train_dataloader, test_dataloader)

    # Final evaluation is currently disabled.
    # evaluate_loop(model,val_dataloader)